// Copyright (c) The OpenTofu Authors
// SPDX-License-Identifier: MPL-2.0
// Copyright (c) 2023 HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package azure_vault

import (
	"context"
	"errors"
	"fmt"
	"net/url"
	"os"
	"regexp"
	"strings"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys"
	"github.com/opentofu/opentofu/internal/backend/remote-state/azure/auth"
	"github.com/opentofu/opentofu/internal/encryption/keyprovider"
)

// keyManagementClientInit is the signature of a constructor that builds a
// keyManagementClient for the vault at vaultURL, authenticating with credential.
type keyManagementClientInit func(vaultURL string, credential azcore.TokenCredential) (keyManagementClient, error)

// newKeyManagementClient is the constructor Build uses to create the vault
// client. By default it calls azkeys.NewClient with nil client options.
// It is a package-level variable so that tests can replace it with a mock.
var newKeyManagementClient keyManagementClientInit = func(vaultURL string, credential azcore.TokenCredential) (keyManagementClient, error) {
	return azkeys.NewClient(vaultURL, credential, nil)
}

// authMethodGetter is the signature of a function that selects an
// auth.AuthMethod for the given authentication configuration.
type authMethodGetter func(ctx context.Context, authConfig *auth.Config) (auth.AuthMethod, error)

// getAuthMethod is the selector Build uses to pick an authentication method.
// By default it delegates to auth.GetAuthMethod. It is a package-level
// variable so that tests can replace it with a mock.
var getAuthMethod authMethodGetter = func(ctx context.Context, authConfig *auth.Config) (auth.AuthMethod, error) {
	return auth.GetAuthMethod(ctx, authConfig)
}

// Config is the HCL-decoded configuration of the Azure Key Vault key provider.
// Call Build to validate it and construct the key provider.
type Config struct {
	// UseCLI toggles Azure CLI authentication. When left unset (nil), Build
	// enables CLI authentication.
	UseCLI *bool `hcl:"use_cli,optional"`

	// Client ID / client secret credentials. When a field is empty, Build falls
	// back to the matching ARM_CLIENT_* environment variable.
	ClientID         string `hcl:"client_id,optional"`
	ClientIDPath     string `hcl:"client_id_file_path,optional"`
	ClientSecret     string `hcl:"client_secret,optional"`
	ClientSecretPath string `hcl:"client_secret_file_path,optional"`

	// Client certificate credentials. When a field is empty, Build falls back
	// to the matching ARM_CLIENT_CERTIFICATE* environment variable.
	ClientCert         string `hcl:"client_certificate,optional"`
	ClientCertPassword string `hcl:"client_certificate_password,optional"`
	ClientCertPath     string `hcl:"client_certificate_path,optional"`

	// OIDC settings. The string fields fall back to the matching ARM_OIDC_*
	// environment variables in Build; the request URL and request token
	// additionally fall back to ACTIONS_ID_TOKEN_REQUEST_URL and
	// ACTIONS_ID_TOKEN_REQUEST_TOKEN.
	UseOIDC          bool   `hcl:"use_oidc,optional"`
	OIDCToken        string `hcl:"oidc_token,optional"`
	OIDCTokenPath    string `hcl:"oidc_token_file_path,optional"`
	OIDCRequestURL   string `hcl:"oidc_request_url,optional"`
	OIDCRequestToken string `hcl:"oidc_request_token,optional"`

	// MSI settings. Build passes these to the auth configuration as given,
	// without an environment variable fallback.
	UseMSI      bool   `hcl:"use_msi,optional"`
	MSIEndpoint string `hcl:"msi_endpoint,optional"`

	// UseAKS enables AKS workload identity authentication.
	UseAKS bool `hcl:"use_aks_workload_identity,optional"`

	// Azure cloud and account selection settings.
	MetadataHost   string `hcl:"metadata_host,optional"`
	Environment    string `hcl:"environment,optional"`
	SubscriptionID string `hcl:"subscription_id,optional"`
	TenantID       string `hcl:"tenant_id,optional"`

	// Vault is the key vault URI, of the format "https://myvaultname.vault.azure.net/".
	// VaultKeyName is the name of the key inside that vault. Both are required
	// and are validated by Build. KeyLength must be at least 1.
	Vault        string `hcl:"vault_uri"`
	VaultKeyName string `hcl:"vault_key_name"`
	KeyLength    int    `hcl:"key_length"`

	// Symmetric indicates whether the configured key is a symmetric encryption key or an asymmetric one.
	// By default this is false, since any key generated by an ordinary Azure Key Vault is always asymmetric;
	// you have to upload a key or generate it in Azure Key Vault Managed HSM (Hardware Security Module) in
	// order to obtain an AES symmetric key.
	Symmetric bool `hcl:"symmetric,optional"`

	// The size of the symmetric key-encryption key in Azure HSM. This is ignored if `symmetric` is false or unset.
	// Must be the value 128, 192, or 256.
	SymmetricKeySize int `hcl:"symmetric_key_size,optional"`
}

// stringAttrEnvFallback resolves a string attribute that may be supplied via
// environment variables. It is stringAttrDefaultEnvFallback with an empty
// default: the result is "" when neither val nor any of envs yields a value.
func stringAttrEnvFallback(val string, envs ...string) string {
	const noDefault = ""
	return stringAttrDefaultEnvFallback(val, noDefault, envs...)
}

// stringAttrDefaultEnvFallback resolves a string attribute using, in order of
// precedence:
//
//  1. val, when it is non-empty;
//  2. the first environment variable in envs that holds a non-empty value;
//  3. def.
//
// An environment variable that is set to the empty string is treated the same
// as an unset one, mirroring how an empty val is treated. Otherwise an empty
// variable (common in CI, where a missing secret expands to "") would suppress
// both the remaining fallbacks in envs and the default.
func stringAttrDefaultEnvFallback(val, def string, envs ...string) string {
	if val != "" {
		return val
	}
	for _, env := range envs {
		if envVal := os.Getenv(env); envVal != "" {
			return envVal
		}
	}
	return def
}

// Build validates the configuration and constructs the Azure Key Vault key
// provider together with its (empty) metadata object.
//
// Validation failures (bad symmetric key size, missing or malformed vault URI
// or key name, key_length below 1) are returned as
// *keyprovider.ErrInvalidConfiguration. Failures while resolving the Azure
// cloud configuration, selecting or constructing credentials, or creating the
// vault client are returned as wrapped errors.
//
// String settings left empty in the configuration fall back to the matching
// ARM_* environment variables.
func (c Config) Build() (keyprovider.KeyProvider, keyprovider.KeyMeta, error) {
	// This mirrors the azurerm remote state backend, minus storage-specific auth
	ctx := context.Background()

	// Asymmetric RSA-OAEP-256 unless the user explicitly declares a symmetric
	// key, in which case the AES-GCM variant is selected by key size.
	algo := azkeys.EncryptionAlgorithmRSAOAEP256
	if c.Symmetric {
		switch c.SymmetricKeySize {
		case 128:
			algo = azkeys.EncryptionAlgorithmA128GCM
		case 192:
			algo = azkeys.EncryptionAlgorithmA192GCM
		case 256:
			algo = azkeys.EncryptionAlgorithmA256GCM
		default:
			return nil, nil, &keyprovider.ErrInvalidConfiguration{Message: "when symmetric is set to true, symmetric_key_size must be given a value of 128, 192, or 256"}
		}
	}

	if c.VaultKeyName == "" {
		return nil, nil, &keyprovider.ErrInvalidConfiguration{Message: "vault_key_name must be provided"}
	}

	if c.Vault == "" {
		return nil, nil, &keyprovider.ErrInvalidConfiguration{Message: "vault_uri must be provided"}
	}

	err := checkKeyNameAndVaultURL(c.VaultKeyName, c.Vault)
	if err != nil {
		return nil, nil, &keyprovider.ErrInvalidConfiguration{Message: fmt.Sprintf("misconfigured key vault name or key name: %s", err.Error())}
	}

	if c.KeyLength < 1 {
		return nil, nil, &keyprovider.ErrInvalidConfiguration{Message: "key_length must be at least 1"}
	}

	// Resolve the cloud from the user's environment / metadata_host settings.
	// These must read their own config fields: feeding the OIDC token in here
	// would both ignore the configured values and pass a secret along as the
	// environment name or metadata host.
	environment := stringAttrDefaultEnvFallback(c.Environment, "public", "ARM_ENVIRONMENT")
	metadataHost := stringAttrEnvFallback(c.MetadataHost, "ARM_METADATA_HOST")

	cloudConfig, _, err := auth.CloudConfigFromAddresses(
		ctx,
		environment,
		metadataHost,
	)

	if err != nil {
		return nil, nil, fmt.Errorf("while obtaining Azure cloud configuration: %w", err)
	}

	// CLI authentication is enabled unless the user explicitly turns it off.
	useCLI := true
	if c.UseCLI != nil {
		useCLI = *c.UseCLI
	}

	authConfig := &auth.Config{
		AzureCLIAuthConfig: auth.AzureCLIAuthConfig{
			CLIAuthEnabled: useCLI,
		},
		ClientSecretCredentialAuthConfig: auth.ClientSecretCredentialAuthConfig{
			ClientID:             stringAttrEnvFallback(c.ClientID, "ARM_CLIENT_ID"),
			ClientIDFilePath:     stringAttrEnvFallback(c.ClientIDPath, "ARM_CLIENT_ID_FILE_PATH"),
			ClientSecret:         stringAttrEnvFallback(c.ClientSecret, "ARM_CLIENT_SECRET"),
			ClientSecretFilePath: stringAttrEnvFallback(c.ClientSecretPath, "ARM_CLIENT_SECRET_FILE_PATH"),
		},
		ClientCertificateAuthConfig: auth.ClientCertificateAuthConfig{
			ClientCertificate:         stringAttrEnvFallback(c.ClientCert, "ARM_CLIENT_CERTIFICATE"),
			ClientCertificatePassword: stringAttrEnvFallback(c.ClientCertPassword, "ARM_CLIENT_CERTIFICATE_PASSWORD"),
			ClientCertificatePath:     stringAttrEnvFallback(c.ClientCertPath, "ARM_CLIENT_CERTIFICATE_PATH"),
		},
		OIDCAuthConfig: auth.OIDCAuthConfig{
			UseOIDC:           c.UseOIDC,
			OIDCToken:         stringAttrEnvFallback(c.OIDCToken, "ARM_OIDC_TOKEN"),
			OIDCTokenFilePath: stringAttrEnvFallback(c.OIDCTokenPath, "ARM_OIDC_TOKEN_FILE_PATH"),
			OIDCRequestURL:    stringAttrEnvFallback(c.OIDCRequestURL, "ARM_OIDC_REQUEST_URL", "ACTIONS_ID_TOKEN_REQUEST_URL"),
			OIDCRequestToken:  stringAttrEnvFallback(c.OIDCRequestToken, "ARM_OIDC_REQUEST_TOKEN", "ACTIONS_ID_TOKEN_REQUEST_TOKEN"),
		},
		MSIAuthConfig: auth.MSIAuthConfig{
			UseMsi:   c.UseMSI,
			Endpoint: c.MSIEndpoint,
		},
		StorageAddresses: auth.StorageAddresses{
			CloudConfig:    cloudConfig,
			SubscriptionID: stringAttrEnvFallback(c.SubscriptionID, "ARM_SUBSCRIPTION_ID"),
			TenantID:       stringAttrEnvFallback(c.TenantID, "ARM_TENANT_ID"),
		},
		WorkloadIdentityAuthConfig: auth.WorkloadIdentityAuthConfig{
			UseAKSWorkloadIdentity: c.UseAKS,
		},
	}

	authMethod, err := getAuthMethod(ctx, authConfig)
	if err != nil {
		return nil, nil, fmt.Errorf("problem getting an auth method: %w", err)
	}

	authCred, err := authMethod.Construct(ctx, authConfig)
	if err != nil {
		return nil, nil, fmt.Errorf("problem getting auth creds: %w", err)
	}

	client, err := newKeyManagementClient(c.Vault, authCred)
	if err != nil {
		return nil, nil, fmt.Errorf("problem making vault client: %w", err)
	}

	return &keyProvider{
		svc:       client,
		ctx:       ctx,
		keyName:   c.VaultKeyName,
		keyLength: c.KeyLength,
		algo:      algo,
	}, new(keyMeta), nil
}

// Patterns for Azure Key Vault resource names, compiled once at package scope
// rather than on every validation call.
var (
	// vaultKeyNamePattern matches a key name: 1-127 letters, digits or hyphens.
	vaultKeyNamePattern = regexp.MustCompile(`^[0-9a-zA-Z\-]{1,127}$`)
	// vaultNamePattern matches a vault name: 3-24 letters, digits or hyphens,
	// starting with a letter and not ending with a hyphen.
	vaultNamePattern = regexp.MustCompile(`^[a-zA-Z][0-9a-zA-Z\-]{1,22}[0-9a-zA-Z]$`)
)

// checkKeyNameAndVaultURL verifies that keyName and the vault name embedded in
// vaultUrl follow Azure's naming rules for Key Vault resources.
//
// The vault name is taken to be the first DNS label of the URL's host (the
// "myvaultname" in "https://myvaultname.vault.azure.net/"). A URL without a
// host therefore yields an empty vault name and is rejected by the vault name
// check.
//
// It returns nil when both names are valid, and otherwise an error describing
// the first rule that was violated, or wrapping the URL parse failure.
func checkKeyNameAndVaultURL(keyName, vaultUrl string) error {
	if !vaultKeyNamePattern.MatchString(keyName) {
		return errors.New("invalid key name: Azure requires a key name consists of 1-127 letters, numbers, or hyphens. See documentation here: https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/resource-name-rules#microsoftkeyvault")
	}

	u, err := url.Parse(vaultUrl)
	if err != nil {
		return fmt.Errorf("invalid key vault URL: %w", err)
	}

	// Everything before the first dot of the host; the whole host if it has
	// no dot, and "" if the URL has no host at all.
	vaultName, _, _ := strings.Cut(u.Hostname(), ".")
	if !vaultNamePattern.MatchString(vaultName) {
		return errors.New("invalid key vault name: Azure requires a key vault name consists of 3-24 letters, numbers, and hyphens only. It must start with a letter and cannot end with a hyphen. See documentation here: https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/resource-name-rules#microsoftkeyvault")
	}
	if strings.Contains(vaultName, "--") {
		return errors.New("invalid key vault name: Hyphens in a key vault name must be nonconsecutive. See documentation here: https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/resource-name-rules#microsoftkeyvault")
	}
	return nil
}
